{% extends 'base.exported.class' %}

{% block content %}
package main

import (
	"encoding/json"
	"fmt"
	"io/ioutil"
	"os"
	"strconv"
)

{% if is_test or to_json %}
// Response is the JSON document that main prints: the index of the predicted
// class under "predict" and the per-class probabilities under "predict_proba".
type Response struct {
	Predict      int       `json:"predict"`
	PredictProba []float64 `json:"predict_proba"`
}

{% endif %}
// {{ class_name }} is a forest of decision trees. Its class probabilities are
// the average of the probabilities computed by every tree in Trees.
type {{ class_name }} struct {
	Trees []{{ class_name }}Tree
}

// {{ class_name }}Tree is a single decision tree stored as parallel slices that
// are all indexed by node number:
//
//   - Thresholds holds the split value of each node; the value -2 marks a leaf.
//   - Indices holds the position in the feature vector that a node compares
//     against its threshold.
//   - Lefts holds the node visited next when the feature is <= the threshold,
//     Rights the node visited otherwise.
//   - Classes holds the per-class values of each node; the values of the leaf
//     that is reached are normalized into probabilities.
type {{ class_name }}Tree struct {
	Lefts      []int       `json:"lefts"`
	Rights     []int       `json:"rights"`
	Thresholds []float64   `json:"thresholds"`
	Indices    []int       `json:"indices"`
	Classes    [][]float64 `json:"classes"`
}

func (forest {{ class_name }}) FindMax(nums []float64) int {
	var idx = 0
	for i := 0; i < len(nums); i++ {
		if nums[i] > nums[idx] {
			idx = i
		}
	}
	return idx
}

func (tree {{ class_name }}Tree) NormVals(nums []float64) []float64 {
	var nNums = len(nums)
	var result = make([]float64, nNums)
	var sum float64
	for i := 0; i < nNums; i++ {
		sum += nums[i]
	}
	if sum == 0 {
		for i := 0; i < nNums; i++ {
			result[i] = 1.0 / float64(nNums)
		}
	} else {
		for i := 0; i < nNums; i++ {
			result[i] = nums[i] / sum
		}
	}
	return result
}

func (tree {{ class_name }}Tree) Compute(features []float64, node int) []float64 {
	for {
		if tree.Thresholds[node] != -2 {
			if features[tree.Indices[node]] <= tree.Thresholds[node] {
				node = tree.Lefts[node]
			} else {
				node = tree.Rights[node]
			}
		} else {
			return tree.NormVals(tree.Classes[node])
		}
	}
}

func (forest {{ class_name }}) Predict(features []float64) int {
	return forest.FindMax(forest.PredictProba(features))
}

func (forest {{ class_name }}) PredictProba(features []float64) []float64 {
	var nClasses = len(forest.Trees[0].Classes[0])
	var nTrees = len(forest.Trees)
	var classProbas = make([]float64, nClasses)
	for _, tree := range forest.Trees {
		var tmp = tree.Compute(features, 0)
		for i := 0; i < nClasses; i++ {
			classProbas[i] += tmp[i]
		}
	}
	for i := 0; i < nClasses; i++ {
		classProbas[i] /= float64(nTrees)
	}
	return classProbas
}

func main() {

	// Features:
	var features []float64
	for _, arg := range os.Args[2:] {
		if n, err := strconv.ParseFloat(arg, 64); err == nil {
			features = append(features, n)
		}
	}

	// Model data:
	jsonFile, err := os.Open(os.Args[1])
	if err != nil {
		fmt.Println(err)
	}
	defer jsonFile.Close()
	byteValue, _ := ioutil.ReadAll(jsonFile)

	// Estimator:
	var trees []{{ class_name }}Tree
	json.Unmarshal(byteValue, &trees)
	var clf {{ class_name }}
	clf.Trees = trees

	{% if is_test or to_json %}
	// Get JSON:
	prediction := clf.Predict(features)
	probabilities := clf.PredictProba(features)
	res, _ := json.Marshal(&Response{Predict: prediction, PredictProba: probabilities})
	fmt.Println(string(res))
	{% else %}
	// Get class prediction:
	prediction := clf.Predict(features)
	fmt.Printf("Predicted class: #%d\n", prediction)
	probabilities := clf.PredictProba(features)
	for i, probability := range probabilities {
		fmt.Printf("Probability of class #%d : %f\n", i, probability)
	}
	{% endif %}

}
{% endblock %}
